{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读写`Tensor`\n",
    "\n",
    "我们可以直接使用`save`函数和`load`函数分别存储和读取`Tensor`。`save`使用Python的pickle实用程序将对象进行序列化，然后将序列化的对象保存到disk，使用`save`可以保存各种对象,包括模型、张量和字典等。而`laod`使用pickle unpickle工具将pickle的对象文件反序列化为内存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "\n",
    "x = torch.ones(10,10)\n",
    "torch.save(x, 'torch_tensor.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "x2 = torch.load('torch_tensor.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "print(x2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存多个数值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.4742, -0.2547, -0.3533],\n",
      "        [ 0.5403, -0.4703, -0.7321],\n",
      "        [ 0.9841,  1.2643, -0.1577],\n",
      "        [ 1.2402,  0.4635,  0.5740],\n",
      "        [-1.1961, -1.0300,  0.1861],\n",
      "        [-0.4376, -0.2366,  0.9359],\n",
      "        [-0.3981,  1.8926,  0.2100],\n",
      "        [ 0.0125, -0.6135, -0.3395],\n",
      "        [ 0.8022,  0.1154,  0.0717],\n",
      "        [ 0.7155,  0.8362,  0.1046]])\n"
     ]
    }
   ],
   "source": [
    "y = torch.randn(10,3)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存 tensor list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]), tensor([[ 1.4742, -0.2547, -0.3533],\n",
      "        [ 0.5403, -0.4703, -0.7321],\n",
      "        [ 0.9841,  1.2643, -0.1577],\n",
      "        [ 1.2402,  0.4635,  0.5740],\n",
      "        [-1.1961, -1.0300,  0.1861],\n",
      "        [-0.4376, -0.2366,  0.9359],\n",
      "        [-0.3981,  1.8926,  0.2100],\n",
      "        [ 0.0125, -0.6135, -0.3395],\n",
      "        [ 0.8022,  0.1154,  0.0717],\n",
      "        [ 0.7155,  0.8362,  0.1046]])]\n"
     ]
    }
   ],
   "source": [
    "torch.save([x2,y], 'torch_tensor.pt')\n",
    "tensor_list = torch.load('torch_tensor.pt')\n",
    "print(tensor_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存 tensor dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),\n",
       " 'y': tensor([[ 1.4742, -0.2547, -0.3533],\n",
       "         [ 0.5403, -0.4703, -0.7321],\n",
       "         [ 0.9841,  1.2643, -0.1577],\n",
       "         [ 1.2402,  0.4635,  0.5740],\n",
       "         [-1.1961, -1.0300,  0.1861],\n",
       "         [-0.4376, -0.2366,  0.9359],\n",
       "         [-0.3981,  1.8926,  0.2100],\n",
       "         [ 0.0125, -0.6135, -0.3395],\n",
       "         [ 0.8022,  0.1154,  0.0717],\n",
       "         [ 0.7155,  0.8362,  0.1046]])}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.save({'x': x2, 'y': y}, 'xy_dict.pt')\n",
    "xy = torch.load('xy_dict.pt')\n",
    "xy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读写模型 保存读取Models 两种方式，保存参数数据，直接保存整个模型，包括结构。\n",
    "\n",
    "## State_dict\n",
    "在PyTorch中，`Module`的可学习参数(即权重和偏差)，模块模型包含在参数中(通过`model.parameters()`访问)。`state_dict`是一个从参数名称隐射到参数`Tesnor`的字典对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=16, out_features=8, bias=True)\n",
      "  (1): Linear(in_features=8, out_features=4, bias=True)\n",
      "  (2): Linear(in_features=4, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(16,8), nn.Linear(8,4), nn.Linear(4,1))\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('0.weight', tensor([[ 0.2331,  0.0276, -0.2451,  0.0109,  0.1566, -0.2014, -0.2156, -0.1280,\n",
      "          0.2056,  0.1407,  0.1083,  0.1882,  0.0551, -0.0400,  0.0443, -0.2372],\n",
      "        [ 0.1692,  0.1618, -0.0517, -0.0812, -0.1555, -0.0163,  0.0334,  0.2266,\n",
      "         -0.2469, -0.1984, -0.1384, -0.0199,  0.0136, -0.1724, -0.1590, -0.0053],\n",
      "        [-0.0704,  0.2221, -0.2207,  0.1495, -0.2040,  0.2320,  0.1979, -0.1726,\n",
      "          0.1717,  0.2442, -0.0479,  0.1695, -0.1683,  0.0465,  0.1016, -0.1294],\n",
      "        [-0.1667, -0.0417,  0.0753,  0.2427,  0.1194,  0.0642, -0.0650,  0.1371,\n",
      "         -0.0284,  0.1100,  0.1789,  0.0374,  0.1587, -0.0658,  0.2488,  0.0230],\n",
      "        [-0.0024, -0.1640, -0.0913, -0.0604,  0.0315,  0.1788, -0.0572,  0.2402,\n",
      "         -0.1028,  0.2186,  0.0836, -0.1235, -0.0566,  0.2199,  0.1654, -0.0090],\n",
      "        [-0.0714, -0.0216,  0.1099,  0.1452, -0.2453,  0.0840,  0.0402,  0.1697,\n",
      "         -0.1719,  0.2065, -0.1270, -0.1910, -0.1444,  0.1259, -0.0575,  0.1424],\n",
      "        [-0.2208, -0.0025, -0.2480,  0.1535, -0.2214, -0.2042,  0.2138, -0.1989,\n",
      "         -0.0068,  0.2189,  0.0176, -0.2460, -0.1771, -0.0226,  0.0893,  0.2168],\n",
      "        [ 0.1223,  0.0245,  0.0185,  0.0465,  0.1577,  0.0220,  0.2485,  0.2201,\n",
      "         -0.1498,  0.0866, -0.1646, -0.1211,  0.1691,  0.0807,  0.0808,  0.0860]])), ('0.bias', tensor([-0.0847,  0.1032,  0.0325,  0.0357, -0.0131, -0.2132, -0.1546,  0.1107])), ('1.weight', tensor([[-0.1156,  0.1067,  0.0955,  0.0615,  0.0132, -0.0102,  0.1994, -0.3311],\n",
      "        [-0.0215,  0.1875, -0.3116,  0.0448, -0.0347,  0.0608,  0.3463,  0.2103],\n",
      "        [-0.0192,  0.1606, -0.0482, -0.2978, -0.1651, -0.1022,  0.1929, -0.0749],\n",
      "        [-0.2968,  0.3489, -0.0084, -0.1591, -0.2460,  0.0155, -0.0495,  0.0225]])), ('1.bias', tensor([-0.3337,  0.3277,  0.2328,  0.2680])), ('2.weight', tensor([[ 0.3168, -0.2412, -0.1459,  0.1880]])), ('2.bias', tensor([-0.2969]))])\n",
      "tensor([[-0.4595],\n",
      "        [-0.3870],\n",
      "        [-0.3978],\n",
      "        [-0.3469],\n",
      "        [-0.6094],\n",
      "        [-0.3387],\n",
      "        [-0.4820],\n",
      "        [-0.5118],\n",
      "        [-0.6691],\n",
      "        [-0.5440]], grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "print(net.state_dict())\n",
    "x = torch.randn(10,16)\n",
    "out = net(x)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意，只有具有可学习参数的层(卷积层、线性层等)才有`state_dict`中的条目。优化器(`optim`)也有一个`state_dict`，其中包含关于优化器状态以及所使用的超参数的信息。\n",
    "\n",
    "```python\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "optimizer.state_dict()\n",
    "```\n",
    "\n",
    "输出："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2276782719624, 2276782808664, 2276782808736, 2276782808808, 2276782808880, 2276782808952]}]}\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "print(optimizer.state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  保存和加载模型\n",
    "\n",
    "PyTorch中保存和加载训练模型有两种常见的方法:\n",
    "\n",
    "1. 仅保存和加载模型参数(`state_dict`)；\n",
    "2. 保存和加载整个模型。\n",
    "\n",
    "### 1. 保存和加载`state_dict`(推荐方式)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), 'linear_net_state_dict.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.4595],\n",
      "        [-0.3870],\n",
      "        [-0.3978],\n",
      "        [-0.3469],\n",
      "        [-0.6094],\n",
      "        [-0.3387],\n",
      "        [-0.4820],\n",
      "        [-0.5118],\n",
      "        [-0.6691],\n",
      "        [-0.5440]], grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "net2 = nn.Sequential(nn.Linear(16,8), nn.Linear(8,4), nn.Linear(4,1))\n",
    "net2.load_state_dict(torch.load('linear_net_state_dict.pt'))\n",
    "print(net2(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 保存和加载整个模型\n",
    "\n",
    "这种方式应该是模型的结构，梯度等数据也都保存了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.4595],\n",
      "        [-0.3870],\n",
      "        [-0.3978],\n",
      "        [-0.3469],\n",
      "        [-0.6094],\n",
      "        [-0.3387],\n",
      "        [-0.4820],\n",
      "        [-0.5118],\n",
      "        [-0.6691],\n",
      "        [-0.5440]], grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "torch.save(net2, 'linear_net_model.pt')\n",
    "\n",
    "#载入pt文件就载入了整个网络\n",
    "model = torch.load('linear_net_model.pt')\n",
    "print(model(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为这`net`和`net2`都有同样的模型参数，那么对同一个输入`X`的计算结果将会是一样的。上面的输出也验证了这一点。\n",
    "\n",
    "此外，还有一些其他使用场景，例如GPU与CPU之间的模型保存与读取、使用多块GPU的模型的存储等等，使用的时候可以参考[官方文档](https://pytorch.org/tutorials/beginner/saving_loading_models.html)。\n",
    "\n",
    "## 小结\n",
    "\n",
    "- 通过`save`函数和`load`函数可以很方便地读写`Tensor`。\n",
    "- 通过`save`函数和`load_state_dict`函数可以很方便地读写模型的参数。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
